{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# node functor\n",
    "\n",
    "源码：`tvm/include/tvm/node/functor.h`"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## `NodeFunctor`\n",
    "\n",
    "`NodeFunctor` 的模板类，它用于根据第一个参数的类型动态分派函数。这个类在构造基于 AST/IR 节点类型的多态分派时非常有用。\n",
    "\n",
    "```c++\n",
    "/*!\n",
    " * \\brief A dynamically dispatched functor on the type of the first argument.\n",
    " *\n",
    " * This is a class that is useful to construct polymorphic dispatching\n",
    " * base on the AST/IR node's type.\n",
    " *\n",
    " * \\code\n",
    " *   NodeFunctor<std::string (const ObjectRef& n, std::string prefix)> tostr;\n",
    " *   tostr.set_dispatch<Add>([](const ObjectRef& op, std::string prefix) {\n",
    " *     return prefix + \"Add\";\n",
    " *   });\n",
    " *   tostr.set_dispatch<IntImm>([](const ObjectRef& op, std::string prefix) {\n",
    " *     return prefix + \"IntImm\"\n",
    " *   });\n",
    " *\n",
    " *   Expr x = make_const(1);\n",
    " *   Expr y = x + x;\n",
    " *   // dispatch to IntImm, outputs \"MyIntImm\"\n",
    " *   LOG(INFO) << tostr(x, \"My\");\n",
    " *   // dispatch to IntImm, outputs \"MyAdd\"\n",
    " *   LOG(INFO) << tostr(y, \"My\");\n",
    " * \\endcode\n",
    " *\n",
    " * \\tparam FType function signiture\n",
    " *  This type if only defined for FType with function signature\n",
    " */\n",
    "template <typename FType>\n",
    "class NodeFunctor;\n",
    "\n",
    "template <typename R, typename... Args>\n",
    "class NodeFunctor<R(const ObjectRef& n, Args...)> {\n",
    " private:\n",
    "  /*! \\brief internal function pointer type */\n",
    "  typedef R (*FPointer)(const ObjectRef& n, Args...);\n",
    "  /*! \\brief refer to itself. */\n",
    "  using TSelf = NodeFunctor<R(const ObjectRef& n, Args...)>;\n",
    "  /*! \\brief internal function table */\n",
    "  std::vector<FPointer> func_;\n",
    "\n",
    " public:\n",
    "  /*! \\brief the result type of this functor */\n",
    "  using result_type = R;\n",
    "  /*!\n",
    "   * \\brief Whether the functor can dispatch the corresponding Node\n",
    "   * \\param n The node to be dispatched\n",
    "   * \\return Whether dispatching function is registered for n's type.\n",
    "   */\n",
    "  bool can_dispatch(const ObjectRef& n) const {\n",
    "    uint32_t type_index = n->type_index();\n",
    "    return type_index < func_.size() && func_[type_index] != nullptr;\n",
    "  }\n",
    "  /*!\n",
    "   * \\brief invoke the functor, dispatch on type of n\n",
    "   * \\param n The Node argument\n",
    "   * \\param args The additional arguments\n",
    "   * \\return The result.\n",
    "   */\n",
    "  R operator()(const ObjectRef& n, Args... args) const {\n",
    "    ICHECK(can_dispatch(n)) << \"NodeFunctor calls un-registered function on type \"\n",
    "                            << n->GetTypeKey();\n",
    "    return (*func_[n->type_index()])(n, std::forward<Args>(args)...);\n",
    "  }\n",
    "  /*!\n",
    "   * \\brief set the dispatcher for type TNode\n",
    "   * \\param f The function to be set.\n",
    "   * \\tparam TNode the type of Node to be dispatched.\n",
    "   * \\return reference to self.\n",
    "   */\n",
    "  template <typename TNode>\n",
    "  TSelf& set_dispatch(FPointer f) {  // NOLINT(*)\n",
    "    uint32_t tindex = TNode::RuntimeTypeIndex();\n",
    "    if (func_.size() <= tindex) {\n",
    "      func_.resize(tindex + 1, nullptr);\n",
    "    }\n",
    "    ICHECK(func_[tindex] == nullptr) << \"Dispatch for \" << TNode::_type_key << \" is already set\";\n",
    "    func_[tindex] = f;\n",
    "    return *this;\n",
    "  }\n",
    "  /*!\n",
    "   * \\brief unset the dispatcher for type TNode\n",
    "   *\n",
    "   * \\tparam TNode the type of Node to be dispatched.\n",
    "   * \\return reference to self.\n",
    "   */\n",
    "  template <typename TNode>\n",
    "  TSelf& clear_dispatch() {  // NOLINT(*)\n",
    "    uint32_t tindex = TNode::RuntimeTypeIndex();\n",
    "    ICHECK_LT(tindex, func_.size()) << \"clear_dispatch: index out of range\";\n",
    "    func_[tindex] = nullptr;\n",
    "    return *this;\n",
    "  }\n",
    "};\n",
    "```"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`NodeFunctor` 类有两个模板参数：`R` 表示返回值类型，`Args...` 表示函数的其他参数类型。它的内部包含函数指针类型 `FPointer`，用于存储指向特定类型的函数的指针。此外，还有名为 `func_` 的内部向量，用于存储这些函数指针。\n",
    "\n",
    "`NodeFunctor` 类提供了以下成员函数：\n",
    "\n",
    "1. `can_dispatch(const ObjectRef& n) const`：检查是否可以对给定的节点进行分派。如果节点的类型索引小于 `func_` 的大小且对应的函数指针不为空，则返回 `true`。\n",
    "2. `operator()(const ObjectRef& n, Args... args) const`：调用分派函数。首先检查是否可以对给定的节点进行分派，然后使用节点的类型索引从 `func_` 中获取相应的函数指针，并调用该函数。\n",
    "3. `set_dispatch(FPointer f)`：为特定类型的节点设置分派函数。首先计算节点类型的运行时类型索引，然后调整 `func_` 的大小以容纳新的函数指针（如果需要），并将新函数指针设置为给定的函数。\n",
    "4. `clear_dispatch()`：清除特定类型的节点的分派函数。首先计算节点类型的运行时类型索引，然后将对应的函数指针设置为 `nullptr`。"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "```c++\n",
    "NodeFunctor<std::string (const ObjectRef& n, std::string prefix)> tostr;\n",
    "tostr.set_dispatch<Add>([](const ObjectRef& op, std::string prefix) {\n",
    "    return prefix + \"Add\";\n",
    "});\n",
    "tostr.set_dispatch<IntImm>([](const ObjectRef& op, std::string prefix) {\n",
    "    return prefix + \"IntImm\"\n",
    "});\n",
    "Expr x = make_const(1);\n",
    "Expr y = x + x;\n",
    "// dispatch to IntImm, outputs \"MyIntImm\"\n",
    "LOG(INFO) << tostr(x, \"My\");\n",
    "// dispatch to IntImm, outputs \"MyAdd\"\n",
    "LOG(INFO) << tostr(y, \"My\");\n",
    "```"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "这段代码定义了名为 `tostr` 的 `NodeFunctor` 对象，该对象用于将不同类型的节点转换为字符串。`NodeFunctor` 是模板类，接受函数签名作为参数，该函数签名表示如何将节点转换为字符串。\n",
    "\n",
    "在这段代码中，`NodeFunctor` 的函数签名为`std::string (const ObjectRef& n, std::string prefix)`，表示它接受 `ObjectRef` 类型的节点和字符串前缀，并返回字符串。\n",
    "\n",
    "接下来，使用 `set_dispatch` 方法为 `NodeFunctor` 设置两个分派函数。第一个分派函数处理 `Add` 类型的节点，它将节点转换为字符串并将前缀添加到字符串末尾。第二个分派函数处理 `IntImm` 类型的节点，它也将节点转换为字符串并将前缀添加到字符串末尾。\n",
    "\n",
    "然后，创建两个表达式 `x` 和 `y`，其中 `x` 是常量节点，值为 `1`。通过将这两个表达式传递给 `tostr` 对象，可以将其转换为字符串。由于 `x` 是 `Add` 类型的节点，因此调用 `tostr(x, \"My\")` 时，将调用第一个分派函数，输出结果为 `\"MyAdd\"`。同样，由于 `y` 是 `IntImm` 类型的节点，因此调用 `tostr(y, \"My\")` 时，将调用第二个分派函数，输出结果为 `\"MyIntImm\"`。"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## `TVM_STATIC_IR_FUNCTOR`\n",
    "\n",
    "```c++\n",
    "#define TVM_REG_FUNC_VAR_DEF(ClsName) static TVM_ATTRIBUTE_UNUSED auto& __make_functor##_##ClsName\n",
    "```\n",
    "这段代码是宏定义，用于生成一个名为 `__make_functor##_##ClsName` 的函数对象。这个函数对象是静态的（static）并且具有 `TVM_ATTRIBUTE_UNUSED` 属性，表示它不会被使用。\n",
    "\n",
    "解析如下：\n",
    "\n",
    "1. `#define` 是 C/C++ 预处理器指令，用于定义宏。\n",
    "2. `TVM_REG_FUNC_VAR_DEF(ClsName)` 是宏的名称，其中 `ClsName` 是参数，表示类名。\n",
    "3. `static TVM_ATTRIBUTE_UNUSED auto& __make_functor##_##ClsName` 是宏的定义部分。\n",
    "   - `static` 表示这是静态成员函数。\n",
    "   - `TVM_ATTRIBUTE_UNUSED` 是属性，表示该变量或函数未被使用，编译器不会发出警告。\n",
    "   - `auto&` 表示返回值类型为引用到自动类型的变量。\n",
    "   - `__make_functor##_##ClsName` 是生成的函数对象的名称，其中 `##` 是连接符，用于将两个字符串连接在一起。\n",
    "\n",
    "综上所述，这段代码的作用是定义名为 `__make_functor##_##ClsName` 的静态函数对象，该函数对象具有 `TVM_ATTRIBUTE_UNUSED` 属性，表示它不会被使用。\n",
    "\n",
    "```c++\n",
    "/*!\n",
    " * \\brief Useful macro to set NodeFunctor dispatch in a global static field.\n",
    " *\n",
    " * \\code\n",
    " *  // Use NodeFunctor to implement ReprPrinter similar to Visitor Pattern.\n",
    " *  // vtable allows easy patch of new Node types, without changing\n",
    " *  // interface of ReprPrinter.\n",
    " *\n",
    " *  class ReprPrinter {\n",
    " *   public:\n",
    " *    std::ostream& stream;\n",
    " *    // the dispatch function.\n",
    " *    void print(Expr e) {\n",
    " *      const static FType& f = *vtable();\n",
    " *      f(e, this);\n",
    " *    }\n",
    " *\n",
    " *    using FType = NodeFunctor<void (const ObjectRef&, ReprPrinter* )>;\n",
    " *    // function to return global function table\n",
    " *    static FType& vtable();\n",
    " *  };\n",
    " *\n",
    " *  // in cpp/cc file\n",
    " *  ReprPrinter::FType& ReprPrinter::vtable() { // NOLINT(*)\n",
    " *    static FType inst; return inst;\n",
    " *  }\n",
    " *\n",
    " *  TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)\n",
    " *  .set_dispatch<Add>([](const ObjectRef& ref, ReprPrinter* p) {\n",
    " *    auto* n = static_cast<const Add*>(ref.get());\n",
    " *    p->print(n->a);\n",
    " *    p->stream << '+'\n",
    " *    p->print(n->b);\n",
    " *  });\n",
    " *\n",
    " *\n",
    " * \\endcode\n",
    " *\n",
    " * \\param ClsName The name of the class\n",
    " * \\param FField The static function that returns a singleton of NodeFunctor.\n",
    " */\n",
    "#define TVM_STATIC_IR_FUNCTOR(ClsName, FField) \\\n",
    "  TVM_STR_CONCAT(TVM_REG_FUNC_VAR_DEF(ClsName), __COUNTER__) = ClsName::FField()\n",
    "```"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "这段代码定义了名为 `TVM_STATIC_IR_FUNCTOR` 的宏，用于设置 `NodeFunctor` 的调度。`NodeFunctor` 是一种用于实现类似于访问者模式的函数对象。\n",
    "\n",
    "在这段代码中，`ReprPrinter` 类使用了 `NodeFunctor` 来实现打印功能。通过使用 `vtable`，可以轻松地为新的节点类型添加新的调度函数，而无需更改 `ReprPrinter` 接口。\n",
    "\n",
    "`ReprPrinter::FType& ReprPrinter::vtable()` 函数返回全局函数表。这个函数表是一个静态成员变量，它存储了 `NodeFunctor` 的实例。\n",
    "\n",
    "`TVM_STATIC_IR_FUNCTOR(ClsName, FField)` 宏的作用是将 `ClsName` 类的 `FField` 函数作为 `NodeFunctor` 的调度函数添加到全局函数表中。这样，当调用 `print` 方法时，会根据节点的类型选择相应的调度函数进行处理。"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "py311",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "name": "python",
   "version": "3.11.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
